# Standard library
import logging
import os
import threading
import time
from enum import Enum
from queue import Queue
from queue import Empty
from typing import Tuple, List

# Third-party
import numpy as np
import onnxruntime
import torch
from PIL import Image
from pycocotools import mask
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
from groundingdino.util.inference import load_model, predict, annotate
import groundingdino.datasets.transforms as T


# --- Model configuration and loading -------------------------------------
# Checkpoints are expected under $HOME/Models.  Everything below runs at
# import time, so importing this module is expensive (disk + GPU I/O).
sam_checkpoint = os.environ["HOME"] + "/Models/sam_vit_h_4b8939.pth"
sam_onnx = os.environ["HOME"] + "/Models/sam_vit_h_4b8939.onnx"
sam_model_type = "vit_h"
groundingdino_model = load_model(
    os.environ["HOME"] + "/Models/groundingdino_swint_ogc.py",
    os.environ["HOME"] + "/Models/groundingdino_swint_ogc.pth")
groundingdino_box_threshold = 0.35
groundingdino_text_threshold = 0.25
# Fall back to CPU when CUDA is unavailable instead of crashing in sam.to().
device = "cuda" if torch.cuda.is_available() else "cpu"

# ONNX decoder session (prompt -> mask) and the PyTorch SAM encoder wrapped
# in a SamPredictor (image -> embedding).
sam_ort_session = onnxruntime.InferenceSession(sam_onnx)
sam = sam_model_registry[sam_model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
sam_predictor = SamPredictor(sam)


# Per-user caches keyed by uid: SAM image embeddings and raw RGB images.
g_sam_embeddings = dict()
g_cached_images = dict()


class AIGCTaskDef(Enum):
    """Kind of inference work an AIGCTask performs."""

    SAM_Backbone = 1           # compute + cache a SAM image embedding
    SAM_Point = 2              # segment from one foreground point
    SAM_Box_Points_Labels = 3  # segment from a box plus labelled points
    CacheImage = 4             # store the raw RGB image for later tasks
    GroundingDINO = 5          # text-prompted detection
    GroundedSAM = 6            # no branch in AIGCTask.run yet
    ChatGLM = 7                # no branch in AIGCTask.run yet


class AIGCInputDef(Enum):
    """Kind of input payload attached to a request."""

    Image = 1
    Embedding = 2
    Text = 3
    Point = 4
    Box_Points_Labels = 5


class AIGCEmbedding:
    """Cached SAM image embedding plus source-image metadata.

    Attributes are per-instance (set in __init__) rather than class-level,
    so instances no longer risk reading state through shared class
    attributes.
    """

    def __init__(self, img_id: int = 0, img_h: int = 0, img_w: int = 0,
                 sam_embedding=None) -> None:
        self.img_id = img_id                # identifier of the source image
        self.img_h = img_h                  # source image height in pixels
        self.img_w = img_w                  # source image width in pixels
        self.sam_embedding = sam_embedding  # np.ndarray from sam_backbone, or None


def groundingdino_input(image: np.ndarray) -> Tuple[np.array, torch.Tensor]:
    """Prepare an RGB image array for GroundingDINO inference.

    Returns the untouched source array together with the resized,
    normalised tensor the detector expects.
    """
    preprocess = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    pil_image = Image.fromarray(image)  # assumes RGB uint8 input
    tensor, _ = preprocess(pil_image, None)
    return image, tensor


def sam_backbone(image: np.ndarray) -> AIGCEmbedding:
    """Run the SAM image encoder and package the embedding with the image size.

    `image` is an RGB(uint8) array of shape (H, W, 3); the returned
    AIGCEmbedding carries the embedding plus img_h/img_w for later
    coordinate scaling.
    """
    sam_predictor.set_image(image)
    result = AIGCEmbedding()
    result.sam_embedding = sam_predictor.get_image_embedding().cpu().numpy()
    result.img_h, result.img_w = image.shape[0], image.shape[1]
    return result


def sam_point(sam_embedding: np.ndarray, px: float, py: float, img_w: int, img_h: int):
    """Segment via the SAM ONNX decoder from a single foreground point.

    px/py are coordinates normalised to [0, 1]; returns the COCO RLE
    'counts' payload of the thresholded mask.
    """
    # One real foreground point plus the padding point/label (-1) the ONNX
    # decoder expects when no box prompt is supplied.
    coords = np.array([[[px * img_w, py * img_h], [0.0, 0.0]]])
    labels = np.array([[1, -1]], dtype=np.float32)
    coords = sam_predictor.transform.apply_coords(coords, (img_h, img_w)).astype(
        np.float32)

    decoder_inputs = {
        "image_embeddings": sam_embedding,
        "point_coords": coords,
        "point_labels": labels,
        "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
        "has_mask_input": np.zeros(1, dtype=np.float32),
        "orig_im_size": np.array((img_h, img_w), dtype=np.float32),
    }

    raw_masks, _, _ = sam_ort_session.run(None, decoder_inputs)
    binary = (raw_masks > sam_predictor.model.mask_threshold).squeeze()

    # pycocotools needs a Fortran-contiguous array for RLE encoding.
    return mask.encode(binary.copy(order='F'))['counts']


def sam_box_points_labels(sam_embedding: np.ndarray,
                          box: np.ndarray, pts: np.ndarray, labels: np.ndarray, img_w: int, img_h: int):
    """Segment via the SAM ONNX decoder from an optional box plus labelled points.

    box: [x0, y0, x1, y1] normalised to [0, 1] (any other length means "no
    box"); pts: (N, 2) normalised coords; labels: (N,) SAM point labels.
    Returns the COCO RLE 'counts' payload of the thresholded mask.

    Fix: scale float32 *copies* of `box` and `pts` instead of mutating the
    caller's arrays in place — the old in-place `*=` corrupted the caller's
    data (re-running a task double-scaled coordinates) and failed outright
    on integer input arrays.
    """
    box = np.array(box, dtype=np.float32)
    pts = np.array(pts, dtype=np.float32)

    if len(box) == 4:
        box[0::2] *= img_w   # x coordinates -> pixels
        box[1::2] *= img_h   # y coordinates -> pixels
        onnx_box_coords = box.reshape(2, 2)
        onnx_box_labels = np.array([2, 3])   # SAM's top-left / bottom-right box labels
    else:
        # Padding point/label the decoder expects when no box is supplied.
        onnx_box_coords = np.array([[0.0, 0.0]])
        onnx_box_labels = np.array([-1])

    pts[:, 0] *= img_w
    pts[:, 1] *= img_h
    onnx_coord = np.concatenate([pts, onnx_box_coords], axis=0)[None, :, :]
    onnx_label = np.concatenate([labels, onnx_box_labels], axis=0)[None, :].astype(
        np.float32)
    onnx_coord = sam_predictor.transform.apply_coords(onnx_coord, (img_h, img_w)).astype(
        np.float32)
    onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
    onnx_has_mask_input = np.zeros(1, dtype=np.float32)

    ort_inputs = {
        "image_embeddings": sam_embedding,
        "point_coords": onnx_coord,
        "point_labels": onnx_label,
        "mask_input": onnx_mask_input,
        "has_mask_input": onnx_has_mask_input,
        "orig_im_size": np.array((img_h, img_w), dtype=np.float32)
    }

    masks, _, _ = sam_ort_session.run(None, ort_inputs)
    masks = masks > sam_predictor.model.mask_threshold
    masks = masks.squeeze()
    # pycocotools needs a Fortran-contiguous array for RLE encoding.
    enc_mask = mask.encode(masks.copy(order='F'))
    return enc_mask['counts']


class AIGCTask:
    """One unit of inference work, executed by the AIGCTaskManager thread.

    Inputs are attached via the set_* methods after construction; run()
    dispatches on self.task and stores its output in self.results.
    """

    def __init__(self, uid: str, task: 'AIGCTaskDef') -> None:
        self.uid = uid          # user identity; also the per-user cache key
        self.task = task        # which AIGCTaskDef variant to execute
        self.req_id = None      # request id, set via set_req_id
        self.image = None       # RGB(uint8) ndarray
        self.text = None        # caption/prompt for GroundingDINO
        self.point = None       # (px, py) normalised coords for SAM_Point
        self.points = None      # (N, 2) normalised coords
        self.labels = None      # (N,) SAM point labels
        self.box = None         # [x0, y0, x1, y1] normalised coords
        self.results = None     # task output, filled in by run()

    def run(self) -> bool:
        """Execute the task; return True on success, False on any failure.

        False covers missing/invalid inputs, exceptions raised by model
        code, and a uid whose embedding/image has not been cached yet.
        """
        done = True
        try:
            if self.task == AIGCTaskDef.SAM_Backbone:
                if self.image is None:
                    raise ValueError("SAM_Backbone requires an image")
                # RGB(uint8) image; embedding cached per uid for later tasks.
                g_sam_embeddings[self.uid] = sam_backbone(self.image)
            elif self.task == AIGCTaskDef.SAM_Point:
                if self.point is None:
                    raise ValueError("SAM_Point requires a point")
                if self.uid in g_sam_embeddings:
                    cached = g_sam_embeddings[self.uid]
                    self.results = sam_point(
                        cached.sam_embedding,
                        self.point[0],
                        self.point[1],
                        cached.img_w,
                        cached.img_h
                    )
                else:
                    done = False  # backbone not run for this uid yet
            elif self.task == AIGCTaskDef.SAM_Box_Points_Labels:
                if self.box is None or self.points is None or self.labels is None:
                    raise ValueError(
                        "SAM_Box_Points_Labels requires box, points and labels")
                if self.uid in g_sam_embeddings:
                    cached = g_sam_embeddings[self.uid]
                    self.results = sam_box_points_labels(
                        cached.sam_embedding,
                        self.box,
                        self.points,
                        self.labels,
                        cached.img_w,
                        cached.img_h
                    )
                else:
                    done = False  # backbone not run for this uid yet
            elif self.task == AIGCTaskDef.CacheImage:
                if self.image is None:
                    raise ValueError("CacheImage requires an image")
                # RGB(uint8) image cached per uid for GroundingDINO.
                g_cached_images[self.uid] = self.image
            elif self.task == AIGCTaskDef.GroundingDINO:
                if self.text is None:
                    raise ValueError("GroundingDINO requires a text prompt")
                if self.uid in g_cached_images:
                    image_source, image = groundingdino_input(g_cached_images[self.uid])
                    boxes, logits, phrases = predict(
                        model=groundingdino_model,
                        image=image,
                        caption=self.text,
                        box_threshold=groundingdino_box_threshold,
                        text_threshold=groundingdino_text_threshold
                    )
                    self.results = {'boxes': boxes.cpu().numpy(), 'logits': logits.cpu().numpy(), 'phrases': phrases}
                else:
                    done = False  # image not cached for this uid yet
            # NOTE(review): GroundedSAM and ChatGLM tasks fall through here
            # and return True without producing results — unimplemented.

        except Exception:
            # Report failure to the caller instead of crashing the worker,
            # but log it so errors are no longer silently swallowed.
            logging.getLogger(__name__).exception(
                "AIGCTask %s failed for uid=%s", self.task, self.uid)
            done = False

        return done

    def set_req_id(self, id_: int) -> None:
        """Attach the request id this task answers."""
        self.req_id = id_

    def set_image(self, image: np.ndarray) -> None:
        """Attach an RGB(uint8) image (SAM_Backbone / CacheImage)."""
        self.image = image

    def set_text(self, text: str) -> None:
        """Attach the GroundingDINO caption/prompt."""
        self.text = text

    def set_point(self, point: np.ndarray) -> None:
        """Attach a normalised (px, py) point (SAM_Point)."""
        self.point = point

    def set_box_points_labels(self, box: np.ndarray, points: np.ndarray, labels: np.ndarray) -> None:
        """Attach box + points + labels (SAM_Box_Points_Labels)."""
        self.box = box
        self.points = points
        self.labels = labels


class AIGCSocket:
    """Transport hook used by AIGCTaskManager to report task outcomes.

    Both callbacks are no-ops here; presumably a subclass (or a patched
    instance) implements the actual wire protocol — TODO confirm against
    the code that constructs AIGCTaskManager.
    """

    def __init__(self):
        pass

    def response_info(self, task: AIGCTask):
        # Invoked when task.run() returned True; task.results holds the output.
        pass

    def response_error(self, task: AIGCTask):
        # Invoked when task.run() returned False (failure or missing cache).
        pass


class AIGCTaskManager(threading.Thread):
    """Worker thread that executes queued AIGCTasks and reports via a socket.

    Fixes over the original:
    - task_queue/running were class attributes shared by every manager
      instance; they are now per-instance state set in __init__.
    - the run loop polled empty() then slept, which races with other
      consumers and adds latency; it now uses a blocking get(timeout=...).
    """

    def __init__(self, socket: 'AIGCSocket'):
        threading.Thread.__init__(self)
        self.socket = socket        # receives response_info/response_error
        self.task_queue = Queue()   # pending AIGCTasks (thread-safe)
        self.running = True         # cleared by stop() to end run()

    def run(self):
        """Consume and execute tasks until stop() is called."""
        while self.running:
            try:
                # Blocking wait with a short timeout so self.running is
                # re-checked periodically.
                task = self.task_queue.get(timeout=0.1)
            except Empty:
                continue
            if task.run():
                self.socket.response_info(task)
            else:
                self.socket.response_error(task)

    def add_task(self, task: 'AIGCTask'):
        """Enqueue a task for the worker thread (safe from any thread)."""
        self.task_queue.put(task)

    def stop(self) -> None:
        """Ask the worker loop to exit after the current task."""
        self.running = False


if __name__ == '__main__':
    # Smoke test: build a backbone task with no image attached and show
    # its fields; run() reports failure (no image) rather than raising.
    demo = AIGCTask('qwer', AIGCTaskDef.SAM_Backbone)
    for value in (demo.uid, demo.task, demo.box, demo.run()):
        print(value)
